import json
import numpy as np
from scipy import stats
from sklearn.metrics import mean_squared_error, r2_score
from math import sqrt
import sys
import os

# --- Configuration ---
# Number of independent Large Language Models (LLMs) used in the ensemble.
# ``calculate_budget_pearson`` uses this as the default for its ``num_llms``
# parameter: a total API-call budget ``B`` is translated into a persona count
# of ``max(1, B // NUM_LLMS)`` before the selected predictions are averaged.
NUM_LLMS = 3

def pearsonr_ci(x, y, alpha=0.05):
    """Return Pearson's r with a Fisher z-transform confidence interval.

    Args:
        x: Sequence of numeric observations.
        y: Sequence of numeric observations, paired index-by-index with ``x``.
        alpha: Significance level; the interval covers ``1 - alpha``
            (default 0.05, i.e. a 95% interval).

    Returns:
        Tuple ``(r, p, lo, hi)``: the correlation coefficient, its two-sided
        p-value as reported by ``scipy.stats.pearsonr``, and the lower/upper
        bounds of the confidence interval.  ``lo`` and ``hi`` are NaN when
        the interval is undefined (three or fewer samples, or a non-finite
        ``r`` such as the one produced by constant input).

    Raises:
        ValueError: Propagated from ``scipy.stats.pearsonr`` for inputs it
            rejects (e.g. fewer than two samples or mismatched lengths).
    """
    n = len(x)
    r, p = stats.pearsonr(x, y)

    # The Fisher standard error is 1/sqrt(n - 3): it divides by zero for
    # n == 3 and is imaginary for n < 3, so no interval can be reported.
    if n <= 3 or not np.isfinite(r):
        return r, p, np.nan, np.nan

    # arctanh(+/-1) is +/-inf; that is the correct limit here (the interval
    # collapses onto +/-1), so silence the divide warning it would trigger.
    with np.errstate(divide='ignore'):
        r_z = np.arctanh(r)
    se = 1.0 / np.sqrt(n - 3)
    z = stats.norm.ppf(1 - alpha / 2)
    lo, hi = np.tanh((r_z - z * se, r_z + z * se))
    return r, p, lo, hi

def extract_persona_predictions(persona_predictions_dict):
    """Collect one numeric prediction per persona from a persona mapping.

    Each value of the mapping is resolved in this order:

    * a dict with a numeric ``"mean_prediction"`` -> that value;
    * a dict with a numeric ``"prediction"`` -> that value;
    * a dict with a ``"predictions"`` list -> the mean of its numeric items
      (the persona is dropped when the list holds no numeric item);
    * a bare number -> that number.

    Anything else is ignored.  A non-dict argument yields an empty list.
    """
    if not isinstance(persona_predictions_dict, dict):
        return []

    numeric_types = (int, float)
    collected = []
    for entry in persona_predictions_dict.values():
        # Bare numeric response: take it as-is.
        if isinstance(entry, numeric_types):
            collected.append(float(entry))
            continue
        if not isinstance(entry, dict):
            continue

        # Pre-aggregated scalar keys win over the raw list, in priority order.
        scalar = next(
            (
                entry[key]
                for key in ("mean_prediction", "prediction")
                if isinstance(entry.get(key), numeric_types)
            ),
            None,
        )
        if scalar is not None:
            collected.append(float(scalar))
            continue

        # Fallback: average whatever numeric items the raw list contains.
        raw_list = entry.get("predictions")
        if isinstance(raw_list, list):
            usable = [float(item) for item in raw_list if isinstance(item, numeric_types)]
            if usable:
                collected.append(float(np.mean(usable)))
    return collected

def _compute_subset_metrics(gt_values, pred_values):
    """Return the metric dictionary for one subset of paired scores.

    Args:
        gt_values: Ground-truth scores.
        pred_values: Predicted scores, aligned index-by-index with
            ``gt_values``.

    Returns:
        dict with the keys ``correlation``, ``correlation_p_value``,
        ``correlation_ci_lower``, ``correlation_ci_upper``, ``rmse``,
        ``r2_score``, ``accuracy``, ``mean_percentage_error`` and
        ``sample_size``.
    """
    gt_arr = np.asarray(gt_values, dtype=float)
    pred_arr = np.asarray(pred_values, dtype=float)
    n = len(gt_arr)

    # scipy's pearsonr rejects fewer than two samples; report NaN for the
    # correlation figures instead of aborting the whole report.
    if n >= 2:
        corr, p_value, ci_lo, ci_hi = pearsonr_ci(gt_arr, pred_arr)
    else:
        corr = p_value = ci_lo = ci_hi = np.nan

    rmse = sqrt(mean_squared_error(gt_arr, pred_arr))
    r2 = r2_score(gt_arr, pred_arr)

    # Binary accuracy: both scores fall on the same side of the threshold 5.
    accuracy = np.mean((gt_arr > 5) == (pred_arr > 5))

    # Percentage error relative to the ground truth.  A ground truth of 0
    # yields inf (or NaN for 0/0); those samples are excluded from the mean.
    with np.errstate(divide='ignore', invalid='ignore'):
        pct_err = 100 * np.abs(gt_arr - pred_arr) / gt_arr
    pct_err = pct_err[np.isfinite(pct_err)]
    mean_pct_err = np.mean(pct_err) if pct_err.size else np.nan

    return {
        'correlation': corr,
        'correlation_p_value': p_value,
        'correlation_ci_lower': ci_lo,
        'correlation_ci_upper': ci_hi,
        'rmse': rmse,
        'r2_score': r2,
        'accuracy': accuracy,
        'mean_percentage_error': mean_pct_err,
        'sample_size': n,
    }


def calculate_metrics(json_file_path, output_file_path):
    """Compute agreement metrics between predictions and ground truth.

    The JSON file must contain a list of sample dicts.  For each sample the
    ground truth is read from ``ground_truth`` (falling back to
    ``mean_score``) and the prediction from, in priority order,
    ``overall_mean_prediction``, ``mean_prediction``, the mean of
    ``persona_predictions`` and finally the legacy ``mean_score`` key.
    Samples lacking either a ground truth or a prediction are skipped.

    Samples whose ``image`` value contains ``"english"`` (case-insensitive)
    form the ``English`` subset, all others the ``Foreign`` subset;
    ``Combined`` is the concatenation of both.

    Args:
        json_file_path: Path of the JSON file to evaluate.
        output_file_path: Path of the text summary to write.

    Returns:
        dict mapping each non-empty subset name (``'English'``,
        ``'Foreign'``, ``'Combined'``) to its metric dictionary.
    """
    # JSON interchange files are UTF-8; do not depend on the locale default.
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    gt_by_subset = {'English': [], 'Foreign': []}
    pred_by_subset = {'English': [], 'Foreign': []}

    for d in data:
        # Ground truth score ('mean_score' is the fallback key).
        gt = d['ground_truth'] if 'ground_truth' in d else d.get('mean_score')
        if gt is None:
            # No ground truth: the sample cannot be scored, so skip it
            # (same policy as calculate_budget_pearson).
            continue

        # --- Obtain model prediction ---
        mean_pred = None
        # Priority 1: overall mean prediction key (newer schema)
        if isinstance(d.get('overall_mean_prediction'), (int, float)):
            mean_pred = d['overall_mean_prediction']
        # Priority 2: single mean prediction key (generic prompt schema)
        elif isinstance(d.get('mean_prediction'), (int, float)):
            mean_pred = d['mean_prediction']
        # Priority 3: derive from persona_predictions if available
        elif 'persona_predictions' in d:
            persona_preds = extract_persona_predictions(d['persona_predictions'])
            if persona_preds:
                mean_pred = float(np.mean(persona_preds))
        # Priority 4: legacy key
        # NOTE(review): when the ground truth also fell back to 'mean_score'
        # the prediction equals the ground truth for this sample -- confirm
        # that legacy files always carry a separate 'ground_truth'.
        if mean_pred is None:
            mean_pred = d.get('mean_score')

        # Skip this datapoint if we still have no valid prediction
        if mean_pred is None:
            continue

        image_name = str(d.get('image', '')).lower()
        subset = 'English' if 'english' in image_name else 'Foreign'
        gt_by_subset[subset].append(gt)
        pred_by_subset[subset].append(mean_pred)

    # Calculate metrics for each non-empty subset, then for their union.
    results = {}
    for name in ('English', 'Foreign'):
        if gt_by_subset[name]:
            results[name] = _compute_subset_metrics(gt_by_subset[name], pred_by_subset[name])

    gt_combined = gt_by_subset['English'] + gt_by_subset['Foreign']
    pred_combined = pred_by_subset['English'] + pred_by_subset['Foreign']
    if gt_combined:
        results['Combined'] = _compute_subset_metrics(gt_combined, pred_combined)

    # Write concise summary to text file (no per-subset detail)
    with open(output_file_path, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\nSUMMARY\n" + "="*80 + "\n")
        if 'Combined' in results:
            c = results['Combined']
            f.write(
                f"Total Samples: {c['sample_size']}\n"
                f"Correlation: {c['correlation']:.4f}\n"
                f"RMSE: {c['rmse']:.4f}\n"
                f"Accuracy (threshold > 5): {c['accuracy']:.4f}\n"
                f"Mean Percentage Error: {c['mean_percentage_error']:.2f}%\n"
            )
        else:
            # 'Combined' exists whenever any subset is non-empty, so this
            # branch only runs with empty results and writes nothing; it is
            # kept as a defensive fallback.
            for name, metrics in results.items():
                f.write(
                    f"{name} – N={metrics['sample_size']}, Corr={metrics['correlation']:.4f}, RMSE={metrics['rmse']:.4f}, "
                    f"Acc={metrics['accuracy']:.4f}, MPE={metrics['mean_percentage_error']:.2f}%\n"
                )

    return results

def calculate_budget_pearson(json_file_path, output_file_path, max_budget=100, step=10, num_llms=NUM_LLMS, min_budget=3):
    """Calculate Pearson correlation for increasing API call budgets **given multiple independent LLMs**.

    A budget ``B`` represents the total number of API calls that can be made
    *across* an ensemble of ``num_llms`` different Large Language Models
    (LLMs).  The global budget is translated into a *persona* budget of
    ``P = max(1, B // num_llms)`` personas.

    Only the first ``P`` personas (in the order they appear in the JSON data)
    contribute to the image-level prediction.  Each selected persona
    contributes **exactly one** prediction -- the first entry of its
    ``predictions`` list -- and personas without a numeric first prediction
    contribute nothing.

    For the *flat* schema (a top-level ``"predictions": [...]`` list) the
    numeric values among the first ``P`` predictions are averaged.

    Args:
        json_file_path: Path of the JSON file (a list of sample dicts).
        output_file_path: Path of the tab-separated ``Budget``/``Pearson``
            table to write.
        max_budget: Largest budget to evaluate (inclusive).
        step: Increment between consecutive budgets.
        num_llms: Number of LLMs the budget is shared between.
        min_budget: First budget to evaluate (default 3).

    Returns:
        dict mapping each budget to its Pearson correlation, or NaN when
        fewer than two samples have a usable prediction at that budget.

    Raises:
        ValueError: If ``num_llms`` is not positive (or, from ``range``, if
            ``step`` is zero).
    """
    if num_llms <= 0:
        raise ValueError(f"num_llms must be a positive integer, got {num_llms!r}")

    # ------------------------------------------------------------------
    # Prepare budgets: [min_budget, min_budget+step, …, max_budget]
    budgets = list(range(min_budget, max_budget + 1, step))
    # The affordable persona count depends only on the budget, so compute it
    # once instead of once per sample.
    personas_for_budget = {b: max(1, b // num_llms) for b in budgets}

    def _mean_of_numeric(values):
        # Average the numeric items only; None/strings would break np.mean.
        numeric = [float(v) for v in values if isinstance(v, (int, float))]
        return float(np.mean(numeric)) if numeric else np.nan

    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Predictions are accumulated per budget in the same order as the
    # ground truths so the two stay aligned for the correlation.
    predictions_by_budget = {b: [] for b in budgets}
    ground_truths = []

    for entry in data:
        # ---------------- Ground truth ----------------
        gt = entry['ground_truth'] if 'ground_truth' in entry else entry.get('mean_score')
        if gt is None:
            # Skip if we cannot determine GT for this sample.
            continue
        try:
            gt_value = float(gt)
        except (TypeError, ValueError):
            # Unusable ground truth: skip before touching any list so the
            # ground-truth and prediction lists cannot drift apart.
            continue
        ground_truths.append(gt_value)

        # ---------------- Prediction schemas ----------------
        # 1. Persona-based: "persona_predictions" : { persona: { predictions: [...] } }
        # 2. Flat list:    "predictions"         : [ ... ]
        # Both are reduced to an ordered candidate list, one slot per API call.
        persona_dict = entry.get('persona_predictions')
        if isinstance(persona_dict, dict) and persona_dict:
            candidates = []
            for resp in persona_dict.values():
                preds = resp.get('predictions') if isinstance(resp, dict) else None
                # First prediction only; None marks a persona with no data,
                # which still occupies its slot in the persona order.
                candidates.append(preds[0] if isinstance(preds, list) and preds else None)
        else:
            flat_preds = entry.get('predictions')
            candidates = flat_preds if isinstance(flat_preds, list) else []

        for b, p in personas_for_budget.items():
            predictions_by_budget[b].append(_mean_of_numeric(candidates[:p]))

    # ---------------- Pearson computation per budget ----------------
    gt_arr = np.asarray(ground_truths, dtype=float)
    budget_results = {}
    for b in budgets:
        pred_arr = np.asarray(predictions_by_budget[b], dtype=float)
        # Keep only the samples that have a prediction at this budget.
        usable = ~np.isnan(pred_arr)
        if np.count_nonzero(usable) < 2:
            budget_results[b] = np.nan  # insufficient data
            continue
        r, _ = stats.pearsonr(gt_arr[usable], pred_arr[usable])
        budget_results[b] = r

    # ---------------- Persist results ----------------
    with open(output_file_path, 'w', encoding='utf-8') as f:
        f.write("Budget\tPearson\n")
        for b in budgets:
            r = budget_results[b]
            if np.isnan(r):
                f.write(f"{b}\tNaN\n")
            else:
                f.write(f"{b}\t{r:.4f}\n")

    return budget_results

def process_json_directory(input_dir, output_dir=None):
    """Run both metric reports for every ``.json`` file found in a directory.

    For each file ``<name>.json`` in ``input_dir`` this writes
    ``<name>_metrics.txt`` and ``<name>_budget_metrics.txt`` into
    ``output_dir`` (which defaults to ``input_dir`` and is created when
    missing).  A failure on one file is printed and does not stop the rest.
    """
    target_dir = input_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)

    # The extension match is case-insensitive.
    candidates = []
    for entry in os.listdir(input_dir):
        if entry.lower().endswith('.json'):
            candidates.append(entry)

    if len(candidates) == 0:
        print(f"No JSON files found in directory '{input_dir}'.")
        return

    print(f"Found {len(candidates)} JSON files in '{input_dir}'. Processing…")
    for filename in candidates:
        stem = os.path.splitext(filename)[0]
        source_path = os.path.join(input_dir, filename)
        metrics_path = os.path.join(target_dir, stem + '_metrics.txt')
        try:
            calculate_metrics(source_path, metrics_path)
            # Budget-based Pearson correlations go to a sibling report file.
            budget_path = os.path.join(target_dir, stem + '_budget_metrics.txt')
            calculate_budget_pearson(source_path, budget_path)
            print(f"Processed {filename} → {os.path.basename(metrics_path)} | Budget metrics: {os.path.basename(budget_path)}")
        except Exception as exc:
            print(f"Error processing {filename}: {exc}")


def main():
    """Command-line entry point.

    Usage::

        python metric.py <input_json_or_directory> [output_dir_or_file]
        python metric.py budget <input_json_or_directory> [output_dir_or_file]

    Without the ``budget`` keyword both the metric summary and the
    budget/Pearson table are produced; with it only the budget table is.
    The input may be a single JSON file or a directory of ``.json`` files.
    Problems are reported on stdout and the function returns ``None``.
    """
    if len(sys.argv) < 2:
        print("Usage: python metric.py <input_json_or_directory> [output_dir_or_file]\n"\
              "       python metric.py budget <input_json_or_directory> [output_dir_or_file]")
        return

    # --- Determine mode (normal vs budget-only) ---
    budget_only = sys.argv[1].lower() == 'budget'
    if budget_only:
        if len(sys.argv) < 3:
            print("Error: Missing input path for budget mode.")
            return
        input_path = sys.argv[2]
        remaining_args_offset = 3
    else:
        input_path = sys.argv[1]
        remaining_args_offset = 2
    has_output_arg = len(sys.argv) > remaining_args_offset

    # -------- Directory input --------
    if os.path.isdir(input_path):
        output_dir = sys.argv[remaining_args_offset] if has_output_arg else input_path
        if not budget_only:
            process_json_directory(input_path, output_dir)
            return

        # Budget-only: create the output directory up front, as
        # process_json_directory does, so a new directory can be targeted.
        os.makedirs(output_dir, exist_ok=True)
        json_files = [f for f in os.listdir(input_path) if f.lower().endswith('.json')]
        if not json_files:
            print(f"No JSON files found in directory '{input_path}'.")
            return
        for jf in json_files:
            in_path = os.path.join(input_path, jf)
            budget_out_path = os.path.join(output_dir, os.path.splitext(jf)[0] + '_budget_metrics.txt')
            try:
                calculate_budget_pearson(in_path, budget_out_path)
                print(f"Budget metrics written for {jf} → {os.path.basename(budget_out_path)}")
            except Exception as e:
                # One bad file must not abort the rest of the batch.
                print(f"Error processing {jf}: {e}")
        return

    # -------- Single-file input --------
    json_file = input_path
    if not os.path.exists(json_file):
        print(f"Error: Input file '{json_file}' not found.")
        return

    if budget_only:
        budget_output_file = sys.argv[remaining_args_offset] if has_output_arg else os.path.splitext(json_file)[0] + '_budget_metrics.txt'
        try:
            calculate_budget_pearson(json_file, budget_output_file)
            print(f"Budget metrics written to {budget_output_file}")
        except Exception as e:
            print(f"Error calculating metrics for {json_file}: {str(e)}")
        return

    output_file = sys.argv[remaining_args_offset] if has_output_arg else os.path.splitext(json_file)[0] + '_metrics.txt'
    try:
        results = calculate_metrics(json_file, output_file)
        # The budget table is always written next to the input file.
        budget_output_file = os.path.splitext(json_file)[0] + '_budget_metrics.txt'
        calculate_budget_pearson(json_file, budget_output_file)

        print(f"Metrics calculated successfully for {json_file} → {output_file} (budget metrics → {budget_output_file})")
        # Brief summary to console (Combined subset)
        if 'Combined' in results:
            comb = results['Combined']
            print(
                f"Summary — Samples: {comb['sample_size']} | Correlation: {comb['correlation']:.4f} | "
                f"RMSE: {comb['rmse']:.4f} | Accuracy: {comb['accuracy']:.4f} | "
                f"Mean PE: {comb['mean_percentage_error']:.2f}%"
            )
    except Exception as e:
        print(f"Error calculating metrics for {json_file}: {str(e)}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
